from matplotlib import pyplot as plt
import seaborn as sns
from IPython import display
# Render inline (notebook) matplotlib figures as SVG.
# NOTE(review): IPython.display.set_matplotlib_formats is deprecated since
# IPython 7.23; its replacement is
# matplotlib_inline.backend_inline.set_matplotlib_formats.
display.set_matplotlib_formats('svg')

# Use the SimHei font for sans-serif text (presumably so Chinese/CJK labels
# render correctly).
plt.rcParams['font.sans-serif'] = ['simhei']
# SimHei has no glyph for the Unicode minus sign (U+2212) that matplotlib
# uses on negative tick labels, so they would be drawn as empty boxes;
# fall back to the ASCII hyphen-minus instead.
plt.rcParams['axes.unicode_minus'] = False

def plot_loss(data, path):
    """Draw a loss curve against its step index and save it as ``<path>loss.png``.

    Args:
        data: Sequence of loss values; the value at position i is drawn at x = i.
        path: Output prefix. It is concatenated directly with ``'loss.png'``,
            so a directory prefix must already end with a path separator.
    """
    print('plot loss.....')
    figure = plt.figure(figsize=(10, 7))
    axes = figure.add_subplot(111)

    steps = list(range(len(data)))
    axes.plot(steps, data)

    target = path + 'loss.png'
    figure.savefig(target, dpi=400)


def plot_scatter(data, center, label, path):
    """Scatter-plot points coloured by cluster, overlay cluster centres, and
    save the figure as ``<path>scatter.png``.

    Args:
        data: Array whose first two columns are plotted as x ("DST") and
            y ("PTDTC"). It must support ``.T`` and boolean-mask indexing
            (presumably a NumPy array; confirm with callers).
        center: Cluster centres. Their first two columns (or the two entries
            of a single 1-D centre) are drawn as black stars.
        label: Per-point cluster labels aligned with ``data``. They are
            compared elementwise with ``==`` to build masks (assumed to be a
            NumPy array).
        path: Output prefix, concatenated directly with ``'scatter.png'``.
    """
    print('plot scatter....')
    fig = plt.figure(figsize=(10, 7))
    ax = fig.add_subplot(111)
    x, y = data.T[0], data.T[1]
    # The first three entries keep the original palette; the extras (cycled
    # via modulo) stop more than three clusters from raising IndexError.
    colors = ['red', 'green', 'blue', 'orange', 'purple', 'brown',
              'pink', 'gray', 'olive', 'cyan']
    # sorted() makes the cluster -> colour/legend mapping deterministic;
    # plain set iteration order is a hashing implementation detail.
    for batch, i in enumerate(sorted(set(label))):
        index = label == i
        ax.scatter(x[index], y[index], c=colors[batch % len(colors)],
                   label="cluster %d" % batch)
    # Take the first two columns, as for ``data``. Unpacking ``center.T``
    # raised ValueError when the centres had more than two features.
    cx, cy = center.T[0], center.T[1]
    ax.scatter(cx, cy, marker='*', c="black", label="cluster center")
    ax.set_xlabel("DST", fontsize=16)
    ax.set_ylabel("PTDTC", fontsize=16)
    ax.legend(loc='best')
    fig.savefig(path + 'scatter.png', dpi=400)


def plot_pdf(data, xlabel, path, file):
    """Plot a density-normalised histogram of *data* with a KDE overlay and
    save it to ``path + file``.

    Args:
        data: Sample of values to summarise (100 histogram bins).
        xlabel: Text for the x-axis label; also echoed in the progress message.
        path: Output prefix, concatenated directly with *file*.
        file: Output file name appended to *path*.
    """
    print('plot pdf %s'%xlabel)
    fig = plt.figure(figsize=(10, 7))
    ax = fig.add_subplot(111)
    if hasattr(sns, 'histplot'):
        # seaborn >= 0.11 deprecates distplot. These arguments follow the
        # documented migration recipe to reproduce distplot's default output:
        # a density histogram, a KDE extended 3 bandwidths past the data, and
        # semi-transparent bars.
        sns.histplot(data, bins=100, stat='density', kde=True,
                     kde_kws={'cut': 3}, alpha=.4, edgecolor=(1, 1, 1, .4),
                     ax=ax)
    else:
        # Older seaborn has no histplot; keep the original call.
        sns.distplot(data, bins=100, ax=ax)
    ax.set_xlabel(xlabel)
    # Save the figure created above explicitly rather than whatever figure
    # happens to be current in pyplot's global state.
    fig.savefig(path + file, dpi=400)
